# Copyright 2017 The TensorFlow Lattice Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lattice layers library for TensorFlow Lattice.

Lattice is an interpolated lookup table (LUT), part of TensorFlow Lattice
models.

This module provides functions used when building models, as opposed to the
basic operators exported by lattice_ops.py
"""
from tensorflow_lattice.python.lib import regularizers
from tensorflow_lattice.python.lib import tools
from tensorflow_lattice.python.ops import lattice_ops
from tensorflow_lattice.python.ops.gen_monotone_lattice import monotone_lattice

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope

_VALID_INTERPOLATION_TYPES = ['hypercube', 'simplex']


def lattice_param_as_linear(lattice_sizes, output_dim, linear_weights=1.0):
  """Returns lattice parameter that represents a normalized linear function.

  For simplicity, let's assume output_dim == 1. This function returns a lattice
  parameter so that

    lattice_param' phi(x) = 1 / (len(lattice_sizes))
      * sum_k (linear_weights[k]/(lattice_sizes[k] - 1)) * x[k] - 0.5.

  where phi(x) is the lattice interpolation weight.
  The normalization in the weights and bias (-0.5) are introduced to make the
  output ranges [-0.5, 0.5] when all linear_weights are 1.0.

  The returned lattice_param can be used to initialize a lattice layer as a
  linear function.

  Args:
    lattice_sizes: (list of ints) A list of lattice sizes of each dimension.
    output_dim: (int) number of outputs.
    linear_weights: (int, float, list of numbers, list of list of numbers)
      linear function's weight terms. linear_weights[k][n] == kth output's nth
      weight. If a single number, then all the weights use one value as
      [[linear_weights] * len(lattice_sizes)] * output_dim.
      If a list of numbers, then len(linear_weights) == len(lattice_sizes) is
      expected, and the weights are [linear_weights] * output_dim, i.e., all
      output_dimension will get same linear_weights.
  Returns:
    List of list of floats with size (output_dim, number_of_lattice_param).
  Raises:
    ValueError: * Any element in lattice_sizes is less than 2.
      * lattice_sizes is empty.
      * If linear_weights is not supported type, or shape of linear_weights are
        not the desired values .
  """
  if not lattice_sizes:
    raise ValueError('lattice_sizes should not be empty')
  for lattice_size in lattice_sizes:
    if lattice_size < 2:
      raise ValueError('All elements in lattice_sizes are expected to greater '
                       'than equal to 2, but got %s' % lattice_sizes)

  lattice_rank = len(lattice_sizes)
  linear_weight_matrix = None
  # Accept plain ints as well as floats; normalize everything to float.
  if isinstance(linear_weights, (int, float)):
    linear_weight_matrix = (
        [[float(linear_weights)] * lattice_rank] * output_dim)
  elif isinstance(linear_weights, list):
    # Branching using the first element in linear_weights. linear_weights[0]
    # should exist, since lattice_sizes is not empty.
    if isinstance(linear_weights[0], (int, float)):
      if len(linear_weights) != lattice_rank:
        raise ValueError(
            'A number of elements in linear_weights (%d) != lattice rank (%d)' %
            (len(linear_weights), lattice_rank))
      # Repeating same weights for all output_dim.
      linear_weight_matrix = [[float(w) for w in linear_weights]] * output_dim
    elif isinstance(linear_weights[0], list):
      # 2d matrix case.
      if len(linear_weights) != output_dim:
        raise ValueError(
            'A number of lists in linear_weights (%d) != output_dim (%d)' %
            (len(linear_weights), output_dim))
      for linear_weight in linear_weights:
        if len(linear_weight) != lattice_rank:
          raise ValueError(
              'linear_weights contain more than one list whose length != '
              'lattice rank(%d)' % lattice_rank)
      linear_weight_matrix = [
          [float(w) for w in row] for row in linear_weights]
    else:
      raise ValueError(
          'Only list of numbers or list of list of numbers are supported')
  else:
    raise ValueError(
        'Only a number or list of numbers or list of list of numbers are '
        'supported.')

  # Create lattice structure to enumerate (index, lattice_dim) pairs.
  lattice_structure = tools.LatticeStructure(lattice_sizes)

  # Normalize linear_weight_matrix.
  lattice_parameters = []
  for linear_weight_per_output in linear_weight_matrix:
    # Bias of -0.5 centers the output range at zero.
    lattice_parameter = [-0.5] * lattice_structure.num_vertices
    for (idx, vertex) in tools.lattice_indices_generator(lattice_structure):
      for dim in range(lattice_rank):
        lattice_parameter[idx] += (linear_weight_per_output[dim] * float(
            vertex[dim]) / float(lattice_rank * (lattice_sizes[dim] - 1)))
    lattice_parameters.append(lattice_parameter)

  return lattice_parameters


def lattice_layer(input_tensor,
                  lattice_sizes,
                  is_monotone=None,
                  output_dim=1,
                  interpolation_type='hypercube',
                  lattice_initializer=None,
                  l1_reg=None,
                  l2_reg=None,
                  l1_torsion_reg=None,
                  l2_torsion_reg=None,
                  l1_laplacian_reg=None,
                  l2_laplacian_reg=None):
  """Creates a lattice layer.

  Returns an output of lattice, lattice parameters, and projection ops.

  Args:
    input_tensor: [batch_size, input_dim] tensor.
    lattice_sizes: A list of lattice sizes of each dimension.
    is_monotone: A list of input_dim booleans, boolean or None. If None or
      False, lattice will not have monotonicity constraints. If
      is_monotone[k] == True, then the lattice output has the non-decreasing
      monotonicity with respect to input_tensor[?, k] (the kth coordinate). If
      True, all the input coordinate will have the non-decreasing monotonicity.
    output_dim: Number of outputs.
    interpolation_type: 'hypercube' or 'simplex'.
    lattice_initializer: (Optional) Initializer for lattice parameter vectors,
      a 2D tensor [output_dim, parameter_dim] (where parameter_dim ==
      lattice_sizes[0] * ... * lattice_sizes[input_dim - 1]). If None,
      lattice_param_as_linear initializer will be used with
      linear_weights=[1 if monotone else 0 for monotone in is_monotone].
    l1_reg: (float) l1 regularization amount.
    l2_reg: (float) l2 regularization amount.
    l1_torsion_reg: (float) l1 torsion regularization amount.
    l2_torsion_reg: (float) l2 torsion regularization amount.
    l1_laplacian_reg: (list of floats or float) list of L1 Laplacian
       regularization amount per each dimension. If a single float value is
       provided, then all dimensions will get the same value.
    l2_laplacian_reg: (list of floats or float) list of L2 Laplacian
       regularization amount per each dimension. If a single float value is
       provided, then all dimensions will get the same value.

  Returns:
    A tuple of:
    * output tensor of shape [batch_size, output_dim]
    * parameter tensor of shape [output_dim, parameter_dim]
    * None or projection ops, that must be applied at each
      step (or every so many steps) to project the model to a feasible space:
      used for bounding the outputs or for imposing monotonicity.
    * None or a regularization loss, if regularization is configured.

  Raises:
    ValueError: for invalid parameters.
  """
  if interpolation_type not in _VALID_INTERPOLATION_TYPES:
    raise ValueError('interpolation_type should be one of {}'.format(
        _VALID_INTERPOLATION_TYPES))

  # Normalize is_monotone into a per-dimension list once, so both the
  # initializer construction and the monotonic projection below can use it.
  # (Previously the cast was duplicated in both places.)
  if is_monotone:
    is_monotone = tools.cast_to_list(is_monotone,
                                     len(lattice_sizes), 'is_monotone')

  if lattice_initializer is None:
    # Default initializer: a linear function that is increasing in the
    # monotone dimensions and flat in the rest.
    if is_monotone:
      linear_weights = [1.0 if monotonic else 0.0 for monotonic in is_monotone]
    else:
      linear_weights = [0.0] * len(lattice_sizes)
    lattice_initializer = lattice_param_as_linear(
        lattice_sizes, output_dim, linear_weights=linear_weights)

  parameter_tensor = variable_scope.get_variable(
      interpolation_type + '_lattice_parameters',
      initializer=lattice_initializer)

  output_tensor = lattice_ops.lattice(
      input_tensor,
      parameter_tensor,
      lattice_sizes,
      interpolation_type=interpolation_type)

  with ops.name_scope('lattice_monotonic_projection'):
    if is_monotone:
      # Project the parameters onto the monotonicity-feasible set; the
      # assign_add of the delta moves the variable to the projected point.
      projected_parameter_tensor = monotone_lattice(
          parameter_tensor,
          lattice_sizes=lattice_sizes,
          is_monotone=is_monotone)
      delta = projected_parameter_tensor - parameter_tensor
      projection_ops = [parameter_tensor.assign_add(delta)]
    else:
      projection_ops = None

  with ops.name_scope('lattice_regularization'):
    reg = regularizers.lattice_regularization(
        parameter_tensor,
        lattice_sizes,
        l1_reg=l1_reg,
        l2_reg=l2_reg,
        l1_torsion_reg=l1_torsion_reg,
        l2_torsion_reg=l2_torsion_reg,
        l1_laplacian_reg=l1_laplacian_reg,
        l2_laplacian_reg=l2_laplacian_reg)

  return (output_tensor, parameter_tensor, projection_ops, reg)


def ensemble_lattices_layer(input_tensor,
                            lattice_sizes,
                            structure_indices,
                            is_monotone=None,
                            output_dim=1,
                            interpolation_type='hypercube',
                            lattice_initializers=None,
                            l1_reg=None,
                            l2_reg=None,
                            l1_torsion_reg=None,
                            l2_torsion_reg=None,
                            l1_laplacian_reg=None,
                            l2_laplacian_reg=None):
  """Creates an ensemble of lattices layer.

  Builds one lattice per entry of structure_indices, where each lattice sees
  only the input dimensions listed in its structure, and returns the per-
  lattice outputs, parameters, projection ops, and a combined regularizer.

  Args:
    input_tensor: [batch_size, input_dim] tensor.
    lattice_sizes: A list of lattice sizes of each dimension.
    structure_indices: A list of list of ints. structure_indices[k] is a list
      of indices that belongs to kth lattices.
    is_monotone: A list of input_dim booleans, boolean or None. If None or
      False, lattice will not have monotonicity constraints. If
      is_monotone[k] == True, then the lattice output has the non-decreasing
      monotonicity with respect to input_tensor[?, k] (the kth coordinate). If
      True, all the input coordinate will have the non-decreasing monotonicity.
    output_dim: Number of outputs.
    interpolation_type: 'hypercube' or 'simplex'.
    lattice_initializers: (Optional) A list of initializer for each lattice
      parameter vectors. lattice_initializer[k] is a 2D tensor
      [output_dim, parameter_dim[k]], where parameter_dim[k] is the number of
      parameter in the kth lattice. If None, lattice_param_as_linear
      initializer will be used with
      linear_weights=[1 if monotone else 0 for monotone in is_monotone].
    l1_reg: (float) l1 regularization amount.
    l2_reg: (float) l2 regularization amount.
    l1_torsion_reg: (float) l1 torsion regularization amount.
    l2_torsion_reg: (float) l2 torsion regularization amount.
    l1_laplacian_reg: (list of floats or float) list of L1 Laplacian
       regularization amount per each dimension. If a single float value is
       provided, then all dimensions will get the same value.
    l2_laplacian_reg: (list of floats or float) list of L2 Laplacian
       regularization amount per each dimension. If a single float value is
       provided, then all dimensions will get the same value.

  Returns:
    A tuple of:
    * a list of output tensors, [batch_size, output_dim], with length
      len(structure_indices), i.e., one for each lattice.
    * a list of parameter tensors shape [output_dim, parameter_dim]
    * None or projection ops, that must be applied at each
      step (or every so many steps) to project the model to a feasible space:
      used for bounding the outputs or for imposing monotonicity.
    * None or a regularization loss, if regularization is configured.

  """
  num_lattices = len(structure_indices)
  lattice_initializers = tools.cast_to_list(lattice_initializers, num_lattices,
                                            'lattice initializers')
  # Broadcast scalar Laplacian amounts to one value per input dimension.
  if l1_laplacian_reg is not None:
    l1_laplacian_reg = tools.cast_to_list(l1_laplacian_reg,
                                          len(lattice_sizes),
                                          'l1_laplacian_reg')
  if l2_laplacian_reg is not None:
    l2_laplacian_reg = tools.cast_to_list(l2_laplacian_reg,
                                          len(lattice_sizes),
                                          'l2_laplacian_reg')
  if is_monotone:
    is_monotone = tools.cast_to_list(is_monotone,
                                     len(lattice_sizes), 'is_monotone')

  # Per-dimension slices of the input: input_slices[k] == input_tensor[:, k].
  input_slices = array_ops.unstack(input_tensor, axis=1)

  output_tensors = []
  param_tensors = []
  projections = []
  regularization = None
  # Build one lattice per structure, restricted to that structure's dims.
  for lattice_index, dims in enumerate(structure_indices):
    with variable_scope.variable_scope('lattice_%d' % lattice_index):
      sub_sizes = [lattice_sizes[dim] for dim in dims]
      sub_monotone = [is_monotone[dim] for dim in dims] if is_monotone else None
      sub_input_tensor = array_ops.stack(
          [input_slices[dim] for dim in dims], axis=1)
      sub_l1_laplacian_reg = (None if l1_laplacian_reg is None else
                              [l1_laplacian_reg[dim] for dim in dims])
      sub_l2_laplacian_reg = (None if l2_laplacian_reg is None else
                              [l2_laplacian_reg[dim] for dim in dims])

      sub_output, sub_param, sub_proj, sub_reg = lattice_layer(
          sub_input_tensor,
          sub_sizes,
          sub_monotone,
          output_dim=output_dim,
          interpolation_type=interpolation_type,
          lattice_initializer=lattice_initializers[lattice_index],
          l1_reg=l1_reg,
          l2_reg=l2_reg,
          l1_torsion_reg=l1_torsion_reg,
          l2_torsion_reg=l2_torsion_reg,
          l1_laplacian_reg=sub_l1_laplacian_reg,
          l2_laplacian_reg=sub_l2_laplacian_reg)

      output_tensors.append(sub_output)
      param_tensors.append(sub_param)
      if sub_proj:
        projections += sub_proj
      regularization = tools.add_if_not_none(regularization, sub_reg)

  return (output_tensors, param_tensors, projections, regularization)
